Dev/fuyuajin/maxtext backend test#557

Open
amd-fuyuajin wants to merge 21 commits into main from dev/fuyuajin/maxtext-backend-test

Conversation

@amd-fuyuajin

No description provided.

@yeandy yeandy requested review from alfuyao1986 and yeandy February 18, 2026 20:35
@amd-fuyuajin amd-fuyuajin removed the request for review from yeandy February 18, 2026 20:35
@yeandy yeandy self-requested a review February 18, 2026 20:35
@yeandy yeandy left a comment


I'm wondering how many of these files we need?

  • primus/backends/maxtext/input_pipeline/_hf_data_processing.py
  • primus/backends/maxtext/input_pipeline/custom_packed_batch.py (I see this is deleted)
  • primus/backends/maxtext/layers/attention_op.py
  • primus/backends/maxtext/layers/attentions.py (I see this is deleted)
  • primus/backends/maxtext/metric_logger.py
  • primus/backends/maxtext/train.py
  • primus/backends/maxtext/train_utils.py

I think they were added in the past for the purposes of patching. @amd-fuyuajin do you know if these are getting patched into the MaxText codebase when you run the training? Even if it is, it might be the same code as what is found in rocm/jax-training:maxtext-v26.1 actually. @llying-001 might know best.

@llying-001
Contributor

llying-001 commented Feb 24, 2026

> I'm wondering how many of these files we need? […]
> I think they were added in the past for the purposes of patching. […]

I updated these files in the Primus repo to stay aligned with the yeandy/update-patches-scaling-patch-v2-checkpoint-restore branch in ROCm/maxtext.
The third_party/maxtext in Primus comes directly from a commit of https://github.com/AI-Hypercomputer/maxtext. This commit is effectively identical to the corresponding commit in the ROCm/maxtext main branch (i.e., there are no functional code differences).
Therefore, the diff between https://github.com/ROCm/maxtext/tree/yeandy/update-patches-scaling-patch-v2-checkpoint-restore and the main branch in ROCm/maxtext serves as the reference for which patch files need to be added on the Primus side. @amd-fuyuajin @yeandy

llying-001 and others added 16 commits February 25, 2026 03:09
- Add timestamp to log filenames to prevent overwriting across runs
- Move tee logging outside the inline script to capture consolidated multi-node output in a single log file
- Make --nodelist conditional via NODE_LIST env variable
- set TF_CPP_MIN_LOG_LEVEL=2. Without this setting, error occurs at the end when all training steps complete.
- XLA_FLAGS is case sensitive. Corrected a few values.
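The logging changes described in this commit can be sketched roughly like this; the paths, log-file naming, and the launch command are placeholders, not the PR's actual script:

```shell
#!/bin/sh
# Hypothetical sketch of the logging behavior (not the real run_slurm_pretrain.sh).
LOG_DIR=${LOG_DIR:-./logs}
mkdir -p "$LOG_DIR"

# A timestamp in the filename prevents overwriting logs across runs.
TS=$(date +%Y%m%d_%H%M%S)
LOG_FILE="$LOG_DIR/pretrain_${TS}.log"

# --nodelist is only passed when the NODE_LIST env variable is set.
NODELIST_ARG=""
[ -n "${NODE_LIST:-}" ] && NODELIST_ARG="--nodelist=${NODE_LIST}"

# Suppress the TF error spam that appears when all training steps complete.
export TF_CPP_MIN_LOG_LEVEL=2

# tee sits outside the inline script so multi-node output lands in one file.
echo "srun $NODELIST_ARG ... (stand-in for the real launch)" 2>&1 | tee "$LOG_FILE"
```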
- detect backend framework in `primus-cli-direct.sh`. Install JAX
  dependencies
- If using AINIC (setting USING_AINIC=1), `03_enable_ainic.sh` will run.
  The `LD_LIBRARY_PATH` is modified to make sure libraries are correctly
  loaded for JAX/MaxText.
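A rough sketch of the detection logic described above; the config key, requirements file names, and the AINIC library path are assumptions for illustration, not the PR's exact code:

```shell
#!/bin/sh
# Hypothetical sketch of backend detection in the spirit of primus-cli-direct.sh.
CONFIG=$(mktemp)
printf 'backend: MaxText\n' > "$CONFIG"   # stand-in for the real runner config

# Pull the backend name out of the YAML config (assumed key: "backend").
BACKEND=$(grep -E '^[[:space:]]*backend:' "$CONFIG" | head -n1 | awk '{print $2}')

if [ "$BACKEND" = "MaxText" ]; then
    REQS=requirements-jax.txt             # JAX/MaxText dependencies
else
    REQS=requirements.txt                 # PyTorch-based backends
fi
echo "would run: pip install -r $REQS"

# When AINIC is enabled, append (not prepend) the assumed plugin path so the
# correct libraries are loaded for JAX/MaxText.
if [ "${USING_AINIC:-0}" = "1" ]; then
    export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+$LD_LIBRARY_PATH:}/opt/amd-anp/build"
fi
```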
- Set XLA_PYTHON_CLIENT_MEM_FRACTION=.93 to avoid HSA_STATUS_ERROR_OUT_OF_RESOURCES error during multi-node training
- Corrected some XLA_FLAGS. It is case sensitive. Values `true` and
  `false` do not need to be capitalized.
- set TF_CPP_MIN_LOG_LEVEL=2 to suppress the error messages at the end of JAX/MaxText training

Here is an example of launching JAX/MaxText training on two nodes:
`./primus-cli --config runner/maxtext-test.yaml slurm srun -N 2 -- train pretrain --config examples/maxtext/configs/MI355X/llama2_7B-pretrain.yaml`
Problem: when running `apt install linux-headers-"$(uname -r)"`, the package name resolved to the wrong version number on some nodes, causing a "package not found" error.
Solution: remove it from the package install list. It does not affect performance.
1. Added examples for using AINIC in training
2. Added more examples for running preflight
3. Updated the argument format for the benchmark gemm command. The script was changed, but the document was not updated.
@zhaoh27 zhaoh27 force-pushed the dev/fuyuajin/maxtext-backend-test branch from 2e31891 to 095b267 Compare February 25, 2026 03:15
@llying-001 llying-001 marked this pull request as ready for review February 25, 2026 03:22
Copilot AI review requested due to automatic review settings February 25, 2026 03:22
Contributor

Copilot AI left a comment


Pull request overview

This PR adds comprehensive support for JAX/MaxText backend testing and multi-node training capabilities, including AINIC network integration, improved checkpointing, and various model architecture enhancements.

Changes:

  • Updated MaxText submodule to a newer commit
  • Added AINIC configuration support with proper environment variable setup and library path ordering
  • Enhanced MaxText backend with improved checkpointing, attention mechanisms, and decoder layer implementations
  • Refactored dependency installation to detect framework type and install appropriate requirements

Reviewed changes

Copilot reviewed 34 out of 35 changed files in this pull request and generated 6 comments.

Show a summary per file

  • third_party/maxtext: Updated MaxText submodule reference to newer commit
  • runner/use_ainic.yaml: New configuration file for AINIC network setup with container options
  • runner/primus-cli-direct.sh: Added framework detection logic to install correct dependencies (JAX vs PyTorch)
  • runner/helpers/hooks/train/pretrain/maxtext/prepare.py: Removed problematic linux-headers package, adjusted memory limits and XLA flags
  • runner/helpers/hooks/03_enable_ainic.sh: Fixed LD_LIBRARY_PATH ordering to append instead of prepend paths
  • runner/.primus.yaml: Uncommented InfiniBand device for AINIC support
  • requirements-jax.txt: Simplified to core dependencies only
  • primus/pretrain.py: Enhanced MaxText path detection to support src subdirectory
  • primus/modules/trainer/maxtext/pre_trainer.py: Extended patching to include initialization, checkpointing, config types, and decoder layers
  • primus/configs/modules/maxtext/trainer_base.yaml: Updated configuration with new parameters and removed deprecated options
  • primus/configs/models/maxtext/llama3.1_405B.yaml: New model configuration for Llama 3.1 405B
  • primus/backends/maxtext/train_utils.py: Refactored emergency checkpoint logic and updated to use max_num_checkpoints_to_keep
  • primus/backends/maxtext/train.py: Major refactor with barrier synchronization, improved error handling, and new training features
  • primus/backends/maxtext/metric_logger.py: Updated to use MetadataKey enum constants
  • primus/backends/maxtext/max_utils.py: Added JAX distributed initialization functions for GPU/CPU/TPU
  • primus/backends/maxtext/layers/moe.py: Updated MoE layer to pass bias parameters
  • primus/backends/maxtext/layers/mixtral.py: New Primus-specific Mixtral decoder layer implementation
  • primus/backends/maxtext/layers/mistral.py: New Primus-specific Mistral decoder layer implementation
  • primus/backends/maxtext/layers/llama2.py: New Primus-specific Llama2 decoder layer implementation
  • primus/backends/maxtext/layers/gemma2.py: New Primus-specific Gemma2 decoder layer implementation
  • primus/backends/maxtext/layers/gemma.py: New Primus-specific Gemma decoder layer implementation
  • primus/backends/maxtext/layers/attentions.py: Removed entire attention implementation file
  • primus/backends/maxtext/layers/attention_op.py: Enhanced CUDNN Flash Attention with packing and context parallelism support
  • primus/backends/maxtext/input_pipeline/custom_packed_batch.py: Removed custom packing implementation
  • primus/backends/maxtext/input_pipeline/_hf_data_processing.py: Updated to use grain's native packing and added instruction format conversion
  • primus/backends/maxtext/configs/types.py: New Primus-specific MaxText config with WandB and Turbo support
  • primus/backends/maxtext/checkpointing.py: Added comprehensive checkpoint loading logic with single replica support
  • examples/run_slurm_pretrain.sh: Added NODE_LIST support and timestamped log files
  • examples/run_pretrain.sh: Reorganized AINIC configuration and updated XLA flags
  • examples/run_local_pretrain.sh: Updated default Docker image to maxtext-v26.1
  • examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml: Reduced batch size from 12 to 11
  • examples/maxtext/configs/MI355X/llama3.1_405B-pretrain.yaml: New training configuration for Llama 3.1 405B model
  • examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml: Updated remat policy
  • docs/cli/PRIMUS-CLI-GUIDE.md: Updated documentation with AINIC configuration examples and corrected command syntax
Comments suppressed due to low confidence (4)

runner/primus-cli-direct.sh:1

  • Array index arithmetic should use proper bash syntax. The expression $((i+1)) correctly increments i, but when used inside an array subscript it should be written as ${args[i+1]} without the extra parentheses, or the current form needs validation that i+1 is within array bounds before access.

runner/primus-cli-direct.sh:1

  • Python code embedded in a bash script should properly close file handles. The open('$cfg_path') should be wrapped in a context manager using with open('$cfg_path') as f: cfg = yaml.safe_load(f) to ensure the file is properly closed even if an exception occurs.

primus/backends/maxtext/max_utils.py:1

  • Operator precedence issue: the condition mixes or and and without parentheses. Due to operator precedence, this evaluates as (self.wandb_save_dir is None) or (self.wandb_save_dir == '' and self.base_output_directory), which may not be the intended logic. Add explicit parentheses: if (self.wandb_save_dir is None or self.wandb_save_dir == '') and self.base_output_directory:

primus/backends/maxtext/max_utils.py:1

  • Same operator precedence issue as above. Should be: if (self.wandb_exp_name is None or self.wandb_exp_name == '') and self.run_name:
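The bounds check suggested for the array access could look like this; the variable names are illustrative, not taken from primus-cli-direct.sh:

```shell
#!/bin/bash
# Hypothetical sketch: validate i+1 is within array bounds before reading
# the next element, as the review recommends.
args=("--config" "runner/maxtext-test.yaml" "slurm")
i=0
if (( i + 1 < ${#args[@]} )); then
    # Array subscripts are already an arithmetic context in bash,
    # so ${args[i+1]} works without an extra $(( )).
    next=${args[i+1]}
    echo "next arg: $next"
fi
```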

accept copilot commit suggestion

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Comment on lines +179 to +220
if [ "${BACKEND:-}" == "MaxText" ]; then
# ------- RCCL/NCCL IB Tuning -------
export IONIC_LOCKFREE=all
export NCCL_GDR_COPY_ENABLE=1
export NCCL_GDR_FLUSH_DISABLE=1
export NCCL_IB_ECE_ENABLE=0
export NCCL_IB_FIFO_TC=184
export NCCL_IB_GID_INDEX=1
export NCCL_IB_PCI_RELAXED_ORDERING=1
export NCCL_IB_TC=96
export NCCL_IB_USE_INLINE=1
export NCCL_IGNORE_CPU_AFFINITY=1
export NCCL_PXN_DISABLE=0
export NET_OPTIONAL_RECV_COMPLETION=1
export RCCL_GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING=0
export RCCL_LL128_FORCE_ENABLE=1
else
export ANP_HOME_DIR=${ANP_HOME_DIR:-"/opt/amd-anp"}
export RCCL_HOME_DIR=${RCCL_HOME_DIR:-"/opt/rccl"}
export MPI_HOME_DIR=${MPI_HOME_DIR:-"/opt/ompi"}
export NCCL_NET_PLUGIN=librccl-anp.so

LOG_INFO_RANK0 "RCCL_HOME_DIR: $RCCL_HOME_DIR"
LOG_INFO_RANK0 "ANP_HOME_DIR: $ANP_HOME_DIR"
LOG_INFO_RANK0 "MPI_HOME_DIR: $MPI_HOME_DIR"

# unset NCCL_IB_GID_INDEX
export NCCL_IB_GID_INDEX=1
# export NCCL_IB_ROCE_VERSION_NUM=2
export NCCL_MAX_P2P_CHANNELS=56
export NCCL_IB_TC=104
export NCCL_IB_FIFO_TC=192
export NET_OPTIONAL_RECV_COMPLETION=1
export NCCL_IB_USE_INLINE=1
export RCCL_GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING=0
export NCCL_GDR_FLUSH_DISABLE=1
export NCCL_DMABUF_ENABLE=0
export NCCL_IGNORE_CPU_AFFINITY=1
export NCCL_IB_QPS_PER_CONNECTION=1

export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/libibverbs:${RCCL_HOME_DIR}/build/release:${ANP_HOME_DIR}/build:${MPI_HOME_DIR}/lib:$LD_LIBRARY_PATH
fi


Can you explain why we need different flags (NCCL_IB_TC, NCCL_IB_FIFO_TC) for the MaxText backend versus the other backends? I think these flags are more related to cluster settings, right?

Author


@llying-001 can explain this better. I did not change any of this part.

Contributor


> Can you explain why we need to have different flags (NCCL_IB_TC, NCCL_IB_FIFO_TC) […]

I extracted these env flags for the MaxText backend from https://github.com/ROCm/MAD/blob/develop/scripts/jax-maxtext/jax_maxtext_multinode_benchmark.sh#L305. They are actually tied to the cluster rather than the backend. Are the env flags in jax_maxtext_multinode_benchmark.sh configured for the Vultr cluster? @yeandy
For the Megatron/Titan backend, which cluster are the env flags in run_pretrain.sh configured for? @zhenhuang12
It would be great if we could unify them.

apt install jq dpkg-dev kmod xz-utils -y
apt install libibverbs-dev ibverbs-utils infiniband-diags -y
apt install rdma-core librdmacm-dev libibverbs-dev libibumad-dev -y
LOG_INFO_RANK0 "========== Install IB required packages for Jax/MaxText Done =========="


These are not JAX/MaxText libraries per se, but rather missing dependencies not found in the public Docker image (like rocm/jax-training:maxtext-v26.1), right? @amd-fuyuajin

We don't need to do this for megatron or torchtitan jobs? Or is this already installed in those dockers? @wenxie-amd

Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These packages are mainly InfiniBand/RDMA libraries. I see they are only installed when NNODES > 1 (line 440). They probably provide the networking stack for distributed training. Again, @llying-001 added this and can explain better.

Contributor


Yes, these packages are dependencies required for REBUILD_BNXT that are missing in the public JAX Docker image (e.g., rocm/jax-training:maxtext-v26.1), but they are already installed in the Torch Docker image (e.g., rocm/primus:v26.1).
